package com.atguigu.realtime.apps

import java.sql.{Connection, PreparedStatement, ResultSet}

import com.alibaba.fastjson.JSON
import com.atguigu.common.constants.TopicConstant
import com.atguigu.realtime.apps.TestApp.{appName, batchDuration, context, groupId, topic}
import com.atguigu.realtime.beans.OrderInfo
import com.atguigu.realtime.utils.{DStreamUtil, DateHandleUtil, JDBCUtil}
import org.apache.kafka.clients.consumer.ConsumerRecord
import org.apache.kafka.common.TopicPartition
import org.apache.spark.streaming.dstream.InputDStream
import org.apache.spark.streaming.kafka010.{HasOffsetRanges, OffsetRange}
import org.apache.spark.streaming.{Seconds, StreamingContext}

import scala.collection.mutable

/**
 * Created by Smexy on 2022/6/29
 *
 *
 *    两种处理方式:
 *        第一种： SparkStreaming只负责将Order_info的明细精确一次消费，幂等输出到数据库中。
 *                  例如，幂等输出到 es，hbase，mysql
 *
 *                在Mysql或es或hbase内部，使用数据库提供的方法对GMV进行计算。
 *
 *                完全可以省略SparkStreaming这个环节，直接使用Canal向数据库同步数据。
 *                在Canal端幂等输出。
 *
 *
 *         第二种：  在SparkStreaming中聚合。 将聚合的结果写入数据库。
 *                    在实时数仓中，使用的比较少。一般用于计算某个特定的需求。
 *
 *                    无法实现幂等输出。类似wordcount。 是全局累加的场景。
 *                    只能借助事务输出。
 */
object GMVApp extends BaseApp {
  override var appName: String = "GMVApp"
  override var groupId: String = "realtime220212"
  override var topic: String = TopicConstant.ORDER_INFO
  override var batchDuration: Int = 10

  /**
   * Reads the last committed offsets for (groupId, topic) from the MySQL
   * `offsets` table, so the stream can resume from exactly where the previous
   * run's transaction committed (exactly-once via transactional sink + offsets).
   *
   * @param groupId consumer group whose offsets are stored
   * @param topic   Kafka topic whose offsets are stored
   * @return immutable map of TopicPartition -> committed offset; empty on first run
   * @throws RuntimeException if the query fails
   */
  def selectHistoryOffsetsFromMysql(groupId: String, topic: String): Map[TopicPartition, Long] = {

    val offsets = new mutable.HashMap[TopicPartition, Long]()

    val sql =
      """
        |
        |select
        |   partitionId,offset
        |from offsets
        |where groupId=? and topic=?
        |
        |""".stripMargin

    var connection: Connection = null
    var ps: PreparedStatement = null
    var resultSet: ResultSet = null
    try {
      connection = JDBCUtil.getConnection()
      ps = connection.prepareStatement(sql)
      ps.setString(1, groupId)
      ps.setString(2, topic)

      resultSet = ps.executeQuery()

      while (resultSet.next()) {
        offsets.put(new TopicPartition(topic, resultSet.getInt("partitionId")), resultSet.getLong("offset"))
      }
    } catch {
      case e: Exception =>
        e.printStackTrace()
        throw new RuntimeException("查询偏移量失败!")
    } finally {
      // Close in reverse order of acquisition; ResultSet is closed explicitly
      // rather than relying on the statement to close it.
      if (resultSet != null) {
        resultSet.close()
      }

      if (ps != null) {
        ps.close()
      }

      if (connection != null) {
        connection.close()
      }
    }

    // Convert the mutable accumulator to an immutable Map for the caller.
    offsets.toMap
  }


  /**
   * Writes the aggregated GMV rows and the batch's Kafka offsets in a SINGLE
   * JDBC transaction. Because the aggregated result cannot be written
   * idempotently (it is an incremental "+= gmv" upsert), atomicity of
   * result + offsets is what provides exactly-once semantics: on failure the
   * whole transaction rolls back and the batch is re-consumed from the old offsets.
   *
   * @param result aggregated ((date, hour), totalAmount) rows for this batch
   * @param ranges Kafka offset ranges consumed by this batch
   * @throws RuntimeException if the transactional write fails
   */
  def writeResultAndOffsetsInCommonTranscation(result: Array[((String, String), Double)], ranges: Array[OffsetRange]): Unit = {

    // Incremental upsert: accumulate gmv for an existing (date, hour) key.
    val sql1 =
      """
        |
        |INSERT INTO gmvstats VALUES(?,?,?)
        |ON DUPLICATE KEY UPDATE gmv = values(gmv) + gmv
        |
        |
        |""".stripMargin

    // Offsets are overwritten (not accumulated): last committed value wins.
    val sql2 =
      """
        |
        |INSERT INTO offsets VALUES(?,?,?,?)
        |ON DUPLICATE KEY UPDATE offset = values(offset)
        |
        |
        |""".stripMargin

    var connection: Connection = null
    var ps1: PreparedStatement = null
    var ps2: PreparedStatement = null
    try {
      connection = JDBCUtil.getConnection()

      // Manage the transaction manually: both batches commit together or not at all.
      connection.setAutoCommit(false)

      ps1 = connection.prepareStatement(sql1)
      ps2 = connection.prepareStatement(sql2)

      for (((date, hour), totalamount) <- result) {
        ps1.setString(1, date)
        ps1.setString(2, hour)
        ps1.setDouble(3, totalamount)

        // Accumulate into the batch instead of executing row by row.
        ps1.addBatch()
      }

      for (offsetRange <- ranges) {
        ps2.setString(1, groupId)
        ps2.setString(2, topic)
        ps2.setInt(3, offsetRange.partition)
        // untilOffset is the NEXT offset to consume, so resuming from it is correct.
        ps2.setLong(4, offsetRange.untilOffset)

        ps2.addBatch()
      }

      // Flush both batches, then commit atomically.
      val res1: Array[Int] = ps1.executeBatch()

      val res2: Array[Int] = ps2.executeBatch()

      connection.commit()

      println("当前写出数据:" + res1.size)
      println("当前写出分区偏移量:" + res2.size)

    } catch {
      case e: Exception =>
        // Guard against NPE: getConnection() itself may have thrown, in which
        // case connection is still null and rollback() would mask the real error.
        if (connection != null) {
          connection.rollback()
        }
        e.printStackTrace()
        // Fixed message: this is the transactional WRITE path, not the offset query.
        throw new RuntimeException("事务写出结果和偏移量失败!")
    } finally {

      if (ps1 != null) {
        ps1.close()
      }

      if (ps2 != null) {
        ps2.close()
      }

      if (connection != null) {
        connection.close()
      }
    }
  }


  def main(args: Array[String]): Unit = {

    context = new StreamingContext("local[*]", appName, Seconds(batchDuration))

    runApp {

      // Load the offsets committed by the last run of this group on this topic.
      val offsetsMap: Map[TopicPartition, Long] = selectHistoryOffsetsFromMysql(groupId, topic)

      val ds: InputDStream[ConsumerRecord[String, String]] = DStreamUtil.createDS(context, groupId, topic, true, offsetsMap)

      ds.foreachRDD(rdd => {

        if (!rdd.isEmpty()) {

          // Capture the offset ranges BEFORE any transformation; only the
          // direct Kafka RDD implements HasOffsetRanges.
          val ranges: Array[OffsetRange] = rdd.asInstanceOf[HasOffsetRanges].offsetRanges

          // Aggregate GMV per (date, hour) and collect to the driver —
          // the result is small (at most one row per hour per day).
          val result: Array[((String, String), Double)] = rdd.map(record => {

            val orderInfo: OrderInfo = JSON.parseObject(record.value(), classOf[OrderInfo])

            // Derive the extra time fields used as the aggregation key.
            orderInfo.create_date = DateHandleUtil.parseDateTimeToDate(orderInfo.create_time)
            orderInfo.create_hour = DateHandleUtil.parseDateTimeToHour(orderInfo.create_time)

            ((orderInfo.create_date, orderInfo.create_hour), orderInfo.total_amount)

          }).reduceByKey(_ + _).collect()

          // Transactional output: result + offsets commit atomically on the driver.
          writeResultAndOffsetsInCommonTranscation(result, ranges)
        }
      })
    }
  }
}
